package com.example.interceptor;

import com.example.dto.Page;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Map;
import java.util.Properties;

/**
 * Custom MyBatis pagination plugin.
 *
 * <p>Intercepts {@code StatementHandler.prepare(Connection, Integer)}. The plugin acts on every
 * mapped statement whose id (mapper namespace + statement id) ends with {@code "Page"}. For each
 * such statement it:
 * <ol>
 *   <li>runs {@code select count(*) from (<original sql>) a} on the same connection, binding the
 *       same parameters;</li>
 *   <li>stores the count in the {@link Page} argument via {@code setTotalSize};</li>
 *   <li>rewrites the statement SQL to {@code <original sql> limit offset, pageSize} before handing
 *       control back to MyBatis.</li>
 * </ol>
 *
 * <p>The {@link Page} is taken from the parameter map key {@code "page"}, so the mapper method
 * should declare {@code @Param("page") Page page}. It is also accepted when it is itself the
 * statement's parameter object. If no {@link Page} can be found, the statement runs unpaginated
 * and a warning is logged.
 *
 * @author wangxiaodan
 * @since 2021/4/1 9:49
 **/
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare",
        args = {Connection.class, Integer.class})})
public class PageInterceptor implements Interceptor {
    private static final Logger logger = LoggerFactory.getLogger(PageInterceptor.class);

    /** Statement-id suffix (namespace + id) that marks a query for pagination. */
    private static final String PAGE_ID_SUFFIX = "Page";

    /** Parameter-map key the {@link Page} argument is expected under, i.e. {@code @Param("page")}. */
    private static final String PAGE_PARAM_KEY = "page";

    /**
     * Core interception logic: counts and paginates statements whose id ends with "Page".
     *
     * @param invocation the intercepted {@code prepare} call; argument 0 is the JDBC connection
     * @return whatever the intercepted {@code prepare} returns
     * @throws Throwable anything raised by the count query or by the proceeding invocation
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // The StatementHandler being prepared (by default a RoutingStatementHandler).
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        // Reflective view of the handler, used to reach its private "delegate" fields.
        // NOTE(review): if another plugin also wraps StatementHandler, the target may be a proxy and
        // the "delegate.*" paths below will not resolve — confirm plugin configuration.
        MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
        MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
        String id = mappedStatement.getId(); // mapper namespace + statement id
        if (!id.endsWith(PAGE_ID_SUFFIX)) {
            return invocation.proceed();
        }

        BoundSql boundSql = statementHandler.getBoundSql();
        Page page = resolvePage(boundSql.getParameterObject());
        if (page == null) {
            // A name that merely ends in "Page" is not proof that pagination was intended;
            // run the query unchanged instead of failing with a ClassCastException/NPE.
            logger.warn("Statement {} ends with \"Page\" but has no Page parameter named \"page\"; "
                    + "running it without pagination", id);
            return invocation.proceed();
        }

        String sql = boundSql.getSql();
        Connection connection = (Connection) invocation.getArgs()[0];
        ParameterHandler parameterHandler = (ParameterHandler) metaObject.getValue("delegate.parameterHandler");
        countTotal(connection, parameterHandler, sql, page);

        String pageSql = buildPageSql(sql, page);
        logger.info("==============final sql:{}", pageSql);
        // Hand the rewritten SQL back to MyBatis.
        metaObject.setValue("delegate.boundSql.sql", pageSql);
        // Let MyBatis continue the normal execution flow.
        return invocation.proceed();
    }

    /**
     * Finds the {@link Page} argument of a statement.
     *
     * @param parameterObject the statement's parameter object (may be null)
     * @return the Page if it is the parameter object itself or is stored under key "page" of a
     *         parameter map; otherwise null
     */
    private static Page resolvePage(Object parameterObject) {
        if (parameterObject instanceof Page) {
            return (Page) parameterObject;
        }
        if (parameterObject instanceof Map) {
            Map<?, ?> params = (Map<?, ?>) parameterObject;
            // Check containsKey first: MyBatis' ParamMap.get throws BindingException for unknown keys.
            Object candidate = params.containsKey(PAGE_PARAM_KEY) ? params.get(PAGE_PARAM_KEY) : null;
            if (candidate instanceof Page) {
                return (Page) candidate;
            }
        }
        return null;
    }

    /**
     * Runs a count query over {@code sql} and stores the result in {@code page}.
     *
     * <p>The statement and result set are closed here. The connection is not closed because it is
     * owned by MyBatis.
     *
     * @param connection the connection the intercepted statement is being prepared on
     * @param parameterHandler the handler that binds the original statement's parameters
     * @param sql the original (unpaginated) SQL
     * @param page receives the total row count
     * @throws SQLException if preparing, binding or executing the count query fails
     */
    private static void countTotal(Connection connection, ParameterHandler parameterHandler,
                                   String sql, Page page) throws SQLException {
        String countSql = "select count(*) from (" + sql + ") a";
        logger.info("=============count sql:{}", countSql);
        try (PreparedStatement countStatement = connection.prepareStatement(countSql)) {
            // The count query wraps the original SQL, so it uses the same placeholders.
            parameterHandler.setParameters(countStatement);
            try (ResultSet rs = countStatement.executeQuery()) {
                if (rs.next()) {
                    page.setTotalSize(rs.getInt(1));
                }
            }
        }
    }

    /**
     * Appends a {@code limit offset, pageSize} clause to {@code sql}.
     *
     * <p>The offset is {@code (startPage - 1) * pageSize}. It is computed in {@code long} to avoid
     * int overflow and clamped at 0, so a start page below 1 yields the first page instead of
     * invalid SQL.
     */
    private static String buildPageSql(String sql, Page page) {
        long offset = Math.max(0L, (long) (page.getStartPage() - 1) * page.getPageSize());
        return sql + " limit " + offset + ", " + page.getPageSize();
    }

    /**
     * Wraps {@code target} in a plugin proxy so this interceptor is registered with MyBatis.
     */
    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /**
     * Receives plugin properties from the MyBatis configuration; this plugin has none.
     */
    @Override
    public void setProperties(Properties properties) {
        // No configurable properties.
    }
}
